{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文本分类简介\n",
    "\n",
    "*Copyright (c) 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.*\n",
    "\n",
    "自然语言处理（Natural Language Processing, NLP）是一种机器学习技术，使计算机具有解释、理解和使用人类语言的能力。现在的各类政企拥有大量的语音和文本数据，这些数据来自各种通信渠道，如电子邮件、文本信息、社交媒体新闻报道、视频、音频等等。他们使用NLP软件来自动处理这些数据，分析信息中的意图或情绪，并实时回应人们的沟通。\n",
    "\n",
    "文本分类任务是 NLP 中的基础任务之一。它对输入的文本预测其类别。新闻标题分类、情感分析等应用背后的技术都是文本分类。\n",
    "\n",
    "在这里，我们使用新闻标题分类为例来展示量子机器学习（Quantum Machine Learning, QML）处理文本分类问题的能力。\n",
    "\n",
    "我们使用房地产行业和汽车行业这两类新闻标题作为数据集，对其进行分类。在这个数据集中，训练集包括400条文本数据，测试集包括100条文本数据。数据的样例如下：\n",
    "\n",
    "- 奔驰GLS怎么样？\n",
    "- 如何评价襄阳房价？\n",
    "- 南宁的城建怎么样？\n",
    "- 保时捷为什么这么贵"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用 QSANN 模型实现新闻标题分类\n",
    "\n",
    "### QSANN 模型简介\n",
    "\n",
    "量子自注意力神经网络（Quantum Self-Attention Neural Networks for Text Classification, QSANN）是一个在监督学习框架下的量子–经典混合算法。它先使用了参数化量子电路（Parameterized Quantum Circuit, PQC）对文本数据进行特征编码，然后再使用自注意力机制进行特征提取，最后使用全连接神经网络处理得到分类结果。\n",
    "\n",
    "总结来说，QSANN 的大致原理如下：\n",
    "\n",
    "1. 将输入文本的每个字映射成对应的参数化量子电路。该电路演化得到的量子态即为这个字对应的特征表示。\n",
    "2. 使用自注意力机制对量子态进行处理，并得到处理后的特征表示。\n",
    "3. 使用全连接神经网络对得到的特征进行处理，并得到预测的分类结果。\n",
    "\n",
    "### 工作流\n",
    "\n",
    "QSANN 是学习类的模型。我们需要先使用数据集对模型进行训练。在训练收敛后，我们便得到了一个训练好的模型，这个模型可以对这类数据进行分类。因此，其工作流如下：\n",
    "\n",
    "1. 制备数据集。\n",
    "2. 使用数据集进行训练，得到训练好的模型。\n",
    "3. 使用该模型对输入的文本进行预测，得到预测结果。\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何使用\n",
    "\n",
    "### 使用模型进行预测\n",
    "\n",
    "这里，我们已经给出了一个训练好的模型，可以直接用于房地产行业和汽车行业的新闻标题分类预测。只需要在 `example.toml` 这个配置文件中进行对应的配置，然后输入命令 `python qsann_classification.py --config example.toml` 即可使用训练好的 QSANN 模型对输入的文本进行预测。\n",
    "\n",
    "### 在线演示\n",
    "\n",
    "这里，我们给出一个在线演示的版本，可以在线进行预测。首先定义配置文件的内容：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_toml = r\"\"\"\n",
    "# 模型配置文件。\n",
    "# 输入当前的任务，可以是 'train' 或者 'test'，分别代表训练和预测。这里我们使用 test，表示我们要进行预测。\n",
    "task = 'test'\n",
    "# 输入要预测的文本。\n",
    "text = '奔驰GLS怎么样？'\n",
    "# 训练好的模型参数文件的文件路径。\n",
    "model_path = 'qsann.pdparams'\n",
    "# 数据集的字典文件路径。\n",
    "vocab_path = 'headlines500/vocab.txt'\n",
    "# 量子电路所包含的量子比特的数量。\n",
    "num_qubits = 6\n",
    "# 自注意力层的层数。\n",
    "num_layers = 1\n",
    "# 词嵌入电路的电路深度。\n",
    "depth_ebd = 1\n",
    "# 注意力机制中 query 电路的电路深度。\n",
    "depth_query = 1\n",
    "# 注意力机制中 key 电路的电路深度。\n",
    "depth_key = 1\n",
    "# 注意力机制中 value 电路的电路深度。\n",
    "depth_value = 1\n",
    "# 对输入文本要预测的类别。\n",
    "classes = ['房地产行业', '汽车行业']\n",
    "\"\"\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来是预测部分的代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入的文本是：奔驰GLS怎么样？。\n",
      "模型的预测结果是：汽车行业。\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
    "\n",
    "import toml\n",
    "from paddle_quantum.qml.qsann import train, inference\n",
    "\n",
    "config = toml.loads(test_toml)\n",
    "task = config.pop('task')\n",
    "if task == 'train':\n",
    "    train(**config)\n",
    "elif task == 'test':\n",
    "    prediction = inference(**config)\n",
    "    text = config['text']\n",
    "    print(f'输入的文本是：{text}。')\n",
    "    print(f'模型的预测结果是：{prediction}。')\n",
    "else:\n",
    "    raise ValueError(\"未知的任务，它可以是'train'或'test'。\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这里，我们只需要修改要配置文件中的 text 的内容，再运行整个代码，就可以快速对其它文本测试。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 注意事项\n",
    "\n",
    "在这里，我们提供的模型是汽车行业和房地产行业的新闻标题文本分类模型。开发者也可以使用自己的数据集来训练对应的模型。\n",
    "\n",
    "### 数据集结构\n",
    "\n",
    "如果想要使用自定义数据集进行训练，只需要按照规则来准备数据集即可。在数据集文件夹中准备 `train.txt` 和 `test.txt`，如果需要验证集的话还有 `dev.txt`。每个文件里使用一行代表一条数据。每行内容包含文本和对应的标签，使用制表符隔开。文本是由空格隔开的文字组成。\n",
    "\n",
    "### 配置文件介绍\n",
    "\n",
    "在 `test.toml` 里有测试所需要的完整的配置文件内容参考。在 `train.toml` 里有训练所需要的完整的配置文件内容参考。使用 `python qsann_classification --config train.toml` 可以对模型进行训练。使用 `python qsann_classification --config test.toml` 可以加载训练好的模型进行测试。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 引用信息\n",
    "\n",
    "```tex\n",
    "@article{li2022quantum,\n",
    "  title={Quantum Self-Attention Neural Networks for Text Classification},\n",
    "  author={Li, Guangxi and Zhao, Xuanqiang and Wang, Xin},\n",
    "  journal={arXiv preprint arXiv:2205.05625},\n",
    "  year={2022}\n",
    "}\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15 (default, Nov 10 2022, 12:46:26) \n[Clang 14.0.6 ]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "49b49097121cb1ab3a8a640b71467d7eda4aacc01fc9ff84d52fcb3bd4007bf1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
